//===- VectorLegalization.cpp - Legalize vectors for lowering to ArmSME ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass legalizes vector operations so they can be lowered to ArmSME.
//
// Note: In the context of this pass 'tile' always refers to an SME tile.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "arm-sme-vector-legalization"

namespace mlir::arm_sme {
#define GEN_PASS_DEF_VECTORLEGALIZATION
#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
} // namespace mlir::arm_sme

using namespace mlir;
using namespace mlir::arm_sme;

namespace {

//===----------------------------------------------------------------------===//
// Decomposition of vector operations larger than an SME tile
//===----------------------------------------------------------------------===//

// Common match failure reasons, shared across the patterns in this file so
// that diagnostics (used by `notifyMatchFailure`) stay consistent.
static constexpr StringLiteral kMatchFailureNotSMETileTypeMultiple(
    "op vector size is not multiple of SME tiles");
static constexpr StringLiteral kMatchFailureUnsupportedMaskOp(
    "op mask is unsupported for legalization/decomposition");
static constexpr StringLiteral
    kMatchFailureNonPermutationMap("op affine map is not a permutation");
static constexpr StringLiteral kMatchFailureNotIllegalToLegal(
    "expected transpose from illegal type to legal type");

/// An SMESubTile represents a single SME-sized sub-tile from decomposing a
/// larger vector type. The (`row`, `col`) are the position of the tile in the
/// original vector type. For example for an [8]x[8] tile with four [4]x[4]
/// sub-tiles, we would have:
///
///           8 x vscale
/// ┌─────────────┬─────────────┐
/// │(0,0)        │(0,4)        │
/// │             │             │
/// ├─────────────┼─────────────┤ 8 x vscale
/// │(4,0)        │(4,4)        │
/// │             │             │
/// └─────────────┴─────────────┘
struct SMESubTile {
  // Note: The units of (row, col) are vscale (as SME tiles are scalable).
  int row{0};
  int col{0};
  // The SME tile type (i.e. the vector type of this sub-tile).
  VectorType type;
};

/// Adds a constant elementwise scalable offset to `indices` (which are of
/// equal length). For example, in the 2D case this would return:
/// { indices[0] + offset[0] * vscale, indices[1] + offset[1] * vscale }
SmallVector<Value, 2> addConstantScalableOffset(OpBuilder &builder,
                                                Location loc,
                                                ValueRange indices,
                                                ArrayRef<int> scalableOffsets) {
  // A single vscale read is shared by all adjusted indices.
  auto vscale = vector::VectorScaleOp::create(builder, loc);
  SmallVector<Value, 2> adjustedIndices;
  adjustedIndices.reserve(indices.size());
  for (auto [index, baseOffset] : llvm::zip_equal(indices, scalableOffsets)) {
    auto scaledOffset = arith::MulIOp::create(
        builder, loc,
        arith::ConstantIndexOp::create(builder, loc, baseOffset), vscale);
    adjustedIndices.push_back(
        arith::AddIOp::create(builder, loc, index, scaledOffset));
  }
  return adjustedIndices;
}

/// Adjusts `indices` (e.g. from a load/store) for a larger vector type to
/// indices for one of the SME sub-tiles it will decompose into.
///
/// For example, if you were to decompose an 8x8 load into four 4x4 tiles, the
/// indices for each tile would need to be adjusted as follows:
///
/// initial indices = [a,b], initial size = 8x8, target size = 4x4
/// ┌─────────────┬─────────────┐
/// │[a,b]        │[a,b+4]      │
/// │             │             │
/// ├─────────────┼─────────────┤
/// │[a+4,b]      │[a+4,b+4]    │
/// │             │             │
/// └─────────────┴─────────────┘
SmallVector<Value, 2> getSMESubTileIndices(OpBuilder &builder, Location loc,
                                           ValueRange indices,
                                           SMESubTile smeTile) {
  // The sub-tile's (row, col) are already in units of vscale.
  SmallVector<int, 2> subTileOffsets = {smeTile.row, smeTile.col};
  return addConstantScalableOffset(builder, loc, indices, subTileOffsets);
}

/// Returns true if `mask` is generated by an operation that can be decomposed
/// for SME. Currently, that is just no mask, or vector.create_mask.
/// TODO: Add support for vector.constant_mask once required for SME.
bool isSupportedMaskOp(Value mask) {
  // No mask is trivially supported.
  if (!mask)
    return true;
  return static_cast<bool>(mask.getDefiningOp<vector::CreateMaskOp>());
}

/// Extracts a mask for an SME sub-tile from the mask of a larger vector type.
/// Returns a null Value when there is no mask.
Value extractSMEMask(OpBuilder &builder, Location loc, Value mask,
                     SMESubTile smeTile) {
  assert(isSupportedMaskOp(mask));
  if (!mask)
    return Value{};
  auto createMask = mask.getDefiningOp<vector::CreateMaskOp>();
  // The operands of `vector.create_mask` (from a 2D perspective) are the
  // coordinates where the mask ends. So we subtract where this tile starts,
  // from the mask operands to get the parameters for this sub-tile.
  SmallVector<Value, 2> subTileMaskDims = addConstantScalableOffset(
      builder, loc, createMask.getOperands(), {-smeTile.row, -smeTile.col});
  auto subTileMaskType = smeTile.type.clone(builder.getI1Type());
  return vector::CreateMaskOp::create(builder, loc, subTileMaskType,
                                      subTileMaskDims)
      .getResult();
}

/// Constructs an iterator that returns each SME tile (with coordinates)
/// contained within a VectorType. For example, if decomposing an [8]x[8] into
/// [4]x[4] tiles, the iterator would yield the tiles: (0, 0), (0, 4), (4, 0),
/// (4, 4). If `transposeIndices` is set, the (row, col) of each yielded tile
/// is swapped.
auto decomposeToSMETiles(OpBuilder &builder, VectorType type,
                         VectorType smeTileType,
                         bool transposeIndices = false) {
  // Maps a tile offset (from StaticTileOffsetRange) to an SMESubTile.
  // Captures by value as the returned range outlives this scope.
  auto makeSubTile = [=](auto indices) {
    int row = static_cast<int>(indices[0]);
    int col = static_cast<int>(indices[1]);
    if (transposeIndices)
      std::swap(row, col);
    return SMESubTile{row, col, smeTileType};
  };
  return llvm::map_range(
      StaticTileOffsetRange(
          type.getShape(),
          {std::min(type.getDimSize(0), smeTileType.getDimSize(0)),
           std::min(type.getDimSize(1), smeTileType.getDimSize(1))}),
      makeSubTile);
}

/// Returns the number of SME tiles that fit into the (2D-scalable) vector type
/// `type`. Assumes `type` is a multiple of the SME tile size.
int getNumberOfSMETilesForVectorType(VectorType type) {
  assert(isMultipleOfSMETileVectorType(type) &&
         "`type` not multiple of SME tiles");
  // Each SME tile is (minNumElts x vscale) x (minNumElts x vscale); the
  // vscale factors cancel when dividing the total element count.
  unsigned minNumElts = getSMETileSliceMinNumElts(type.getElementType());
  int64_t elementsPerTile = int64_t(minNumElts) * minNumElts;
  int64_t totalElements = type.getDimSize(0) * type.getDimSize(1);
  return int(totalElements / elementsPerTile);
}

/// Legalize `arith.constant dense<value>` splat operations to fit within SME
/// tiles by decomposing them into tile-sized operations. The op is replaced
/// with one tile-sized splat constant per SME tile.
struct LegalizeArithConstantOpsByDecomposition
    : public OpConversionPattern<arith::ConstantOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::ConstantOp constantOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only dense splat vector constants are handled here.
    auto resultType = dyn_cast<VectorType>(constantOp.getType());
    auto splatAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
    if (!resultType || !splatAttr || !splatAttr.isSplat())
      return failure();

    if (!isMultipleOfSMETileVectorType(resultType))
      return rewriter.notifyMatchFailure(constantOp,
                                         kMatchFailureNotSMETileTypeMultiple);

    // A single tile-sized splat, repeated once per decomposed SME tile.
    auto smeTileType = getSMETileTypeForElement(resultType.getElementType());
    int numTiles = getNumberOfSMETilesForVectorType(resultType);
    auto smeTileSplat = arith::ConstantOp::create(
        rewriter, constantOp.getLoc(), splatAttr.resizeSplat(smeTileType));
    rewriter.replaceOpWithMultiple(
        constantOp, {SmallVector<Value>(numTiles, smeTileSplat)});

    return success();
  }
};

/// Legalize `vector.outerproduct` operations to fit within SME tiles by
/// decomposing them into tile-sized operations. The LHS/RHS operands are
/// sliced with `vector.scalable.extract` and each sub-tile gets its own
/// (possibly masked) tile-sized outer product.
struct LegalizeVectorOuterProductOpsByDecomposition
    : public OpConversionPattern<vector::OuterProductOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::OuterProductOp outerProductOp,
                  OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vectorType = outerProductOp.getResultVectorType();
    if (!isMultipleOfSMETileVectorType(vectorType))
      return rewriter.notifyMatchFailure(outerProductOp,
                                         kMatchFailureNotSMETileTypeMultiple);

    // If the op is masked, the enclosing `vector.mask` op is the one being
    // replaced, so treat it as the root. New ops must also be created
    // outside the `vector.mask` region, hence the insertion point change.
    Value mask;
    Operation *rootOp = outerProductOp;
    auto loc = outerProductOp.getLoc();
    if (outerProductOp.isMasked()) {
      auto maskOp = outerProductOp.getMaskingOp();
      mask = maskOp.getMask();
      rootOp = maskOp;
      rewriter.setInsertionPoint(rootOp);
    }

    if (!isSupportedMaskOp(mask))
      return rewriter.notifyMatchFailure(outerProductOp,
                                         kMatchFailureUnsupportedMaskOp);

    // The accumulator (if present) has already been decomposed into SME
    // tiles by the type conversion (hence the 1:N adaptor).
    ValueRange accSMETiles = adaptor.getAcc();
    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
    // The 1-D type of the LHS/RHS slices feeding each sub-tile.
    VectorType sliceType = VectorType::Builder(smeTileType).dropDim(0);

    SmallVector<Value> resultSMETiles;
    for (auto [index, smeTile] : llvm::enumerate(
             decomposeToSMETiles(rewriter, vectorType, smeTileType))) {

      auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
      // The LHS slice is selected by the sub-tile's row, the RHS slice by
      // its column.
      auto lhs = vector::ScalableExtractOp::create(
          rewriter, loc, sliceType, outerProductOp.getLhs(), smeTile.row);
      auto rhs = vector::ScalableExtractOp::create(
          rewriter, loc, sliceType, outerProductOp.getRhs(), smeTile.col);
      auto smeOuterProduct = vector::OuterProductOp::create(
          rewriter, loc, smeTileType, lhs, rhs,
          !accSMETiles.empty() ? accSMETiles[index] : Value{},
          outerProductOp.getKind());

      // Re-apply the (sub-tile) mask, if any, to the new outer product.
      auto *maskedOuterProduct =
          vector::maskOperation(rewriter, smeOuterProduct, smeMask);
      resultSMETiles.push_back(maskedOuterProduct->getResult(0));
    }

    rewriter.replaceOpWithMultiple(rootOp, {resultSMETiles});
    return success();
  }
};

// Workaround for `vector.mask`. We want to match on `vector.outerproduct` (to
// get the help of the type conversion), but doing so results in the type
// conversion adding target materializations in the `vector.mask` region
// (invalid). This pattern matches on `vector.mask` then calls into the
// `vector.outerproduct` pattern to work around this issue.
struct LegalizeMaskedVectorOuterProductOpsByDecomposition
    : public OpConversionPattern<vector::MaskOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::MaskOp maskOp, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // `getMaskableOp()` may return null, so use the `_if_present` cast
    // (the modern spelling of the deprecated `dyn_cast_or_null`).
    if (auto outerProductOp = llvm::dyn_cast_if_present<vector::OuterProductOp>(
            maskOp.getMaskableOp())) {
      // Delegate to the outer product pattern, which handles replacing the
      // enclosing `vector.mask` op (see its masked-op handling).
      LegalizeVectorOuterProductOpsByDecomposition pattern(*getTypeConverter(),
                                                           getContext());
      return static_cast<RewritePattern &>(pattern).matchAndRewrite(
          outerProductOp, rewriter);
    }
    return failure();
  }
};

/// Legalize `vector.transfer_read` operations to fit within SME tiles by
/// decomposing them into tile-sized reads (one per SME sub-tile), adjusting
/// the indices for each sub-tile's position within the original vector.
struct LegalizeTransferReadOpsByDecomposition
    : public OpConversionPattern<vector::TransferReadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferReadOp readOp, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType vectorType = readOp.getVectorType();
    if (!isMultipleOfSMETileVectorType(vectorType))
      return rewriter.notifyMatchFailure(readOp,
                                         kMatchFailureNotSMETileTypeMultiple);

    Value mask = readOp.getMask();
    if (!isSupportedMaskOp(mask))
      return rewriter.notifyMatchFailure(readOp,
                                         kMatchFailureUnsupportedMaskOp);

    AffineMap permutationMap = readOp.getPermutationMap();
    if (!permutationMap.isPermutation())
      return rewriter.notifyMatchFailure(readOp,
                                         kMatchFailureNonPermutationMap);

    // Note: For 2D vector types the only non-identity permutation is a simple
    // transpose [1, 0].
    bool transposed = !permutationMap.isIdentity();

    Location loc = readOp.getLoc();
    VectorType smeTileType =
        getSMETileTypeForElement(vectorType.getElementType());

    SmallVector<Value> smeReads;
    for (SMESubTile smeTile :
         decomposeToSMETiles(rewriter, vectorType, smeTileType, transposed)) {
      Value smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
      smeReads.push_back(vector::TransferReadOp::create(
          rewriter, loc, smeTileType, readOp.getBase(),
          getSMESubTileIndices(rewriter, loc, readOp.getIndices(), smeTile),
          readOp.getPermutationMapAttr(), readOp.getPadding(), smeMask,
          readOp.getInBoundsAttr()));
    }

    rewriter.replaceOpWithMultiple(readOp, {smeReads});
    return success();
  }
};

/// Legalize `vector.transfer_write` operations to fit within SME tiles by
/// decomposing them into tile-sized writes. For tensor semantics the writes
/// are chained (each write consumes the previous write's result tensor); for
/// memref semantics the original op is simply erased after emitting the
/// sub-tile writes.
struct LegalizeTransferWriteOpsByDecomposition
    : public OpConversionPattern<vector::TransferWriteOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferWriteOp writeOp, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType vectorType = writeOp.getVectorType();
    if (!isMultipleOfSMETileVectorType(vectorType))
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureNotSMETileTypeMultiple);

    Value mask = writeOp.getMask();
    if (!isSupportedMaskOp(mask))
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureUnsupportedMaskOp);

    AffineMap permutationMap = writeOp.getPermutationMap();
    if (!permutationMap.isPermutation())
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureNonPermutationMap);

    // Note: For 2D vector types the only non-identity permutation is a simple
    // transpose [1, 0].
    bool transposed = !permutationMap.isIdentity();
    bool onTensor = writeOp.hasPureTensorSemantics();

    Location loc = writeOp.getLoc();
    VectorType smeTileType =
        getSMETileTypeForElement(vectorType.getElementType());
    // The stored value has already been decomposed into SME tiles.
    auto inputSMETiles = adaptor.getValueToStore();

    Value dest = writeOp.getBase();
    for (auto [index, smeTile] : llvm::enumerate(decomposeToSMETiles(
             rewriter, vectorType, smeTileType, transposed))) {
      Value smeMask = extractSMEMask(rewriter, loc, mask, smeTile);
      auto smeWrite = vector::TransferWriteOp::create(
          rewriter, loc, inputSMETiles[index], dest,
          getSMESubTileIndices(rewriter, loc, writeOp.getIndices(), smeTile),
          writeOp.getPermutationMapAttr(), smeMask, writeOp.getInBoundsAttr());
      // With tensor semantics each write yields a new tensor to write into.
      if (onTensor)
        dest = smeWrite.getResult();
    }

    if (onTensor)
      rewriter.replaceOp(writeOp, dest);
    else
      rewriter.eraseOp(writeOp);

    return success();
  }
};

/// Legalize a multi-tile transfer_write as a single store loop. This is done as
/// part of type decomposition as at this level we know each tile write is
/// disjoint, but that information is lost after decomposition (without analysis
/// to reconstruct it).
///
/// Example (pseudo-MLIR):
///
/// ```
/// vector.transfer_write %vector, %dest[%y, %x], %mask
///   : vector<[16]x[8]xi16>, memref<?x?xi16>
/// ```
/// Is rewritten to:
/// ```
/// scf.for %slice_idx = %c0 to %c8_vscale step %c1 {
///   %upper_slice_mask = vector.extract %mask[%slice_idx] ─┐
///     : vector<[8]xi1> from vector<[16]x[8]xi1>           |
///   %upper_slice = vector.extract %upper_tile[%slice_idx] |- Store upper tile
///     : vector<[8]xi16> from vector<[8]x[8]xi16>          |
///   vector.transfer_write %upper_slice,                   |
///     %dest[%slice_idx + %y, %x], %upper_slice_mask       |
///     : vector<[8]xi16>, memref<?x?xi16>                  ┘
///   %lower_slice_idx = %slice_idx + %c8_vscale                 ─┐
///   %lower_slice_mask = vector.extract %mask[%lower_slice_idx]  |
///     : vector<[8]xi1> from vector<[16]x[8]xi1>                 |
///   %lower_slice = vector.extract %lower_tile[%slice_idx]       |- Store lower
///     : vector<[8]xi16> from vector<[8]x[8]xi16>                |  tile
///   vector.transfer_write %lower_slice,                         |
///     %dest[%lower_slice_idx + %y, %x], %lower_slice_mask       |
///     : vector<[8]xi16>, memref<?x?xi16>                        ┘
/// }
/// ```
struct LegalizeMultiTileTransferWriteAsStoreLoop
    : public OpConversionPattern<vector::TransferWriteOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferWriteOp writeOp, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (writeOp.hasPureTensorSemantics())
      return rewriter.notifyMatchFailure(
          writeOp, "TODO: tensor semantics are unsupported");

    auto permutationMap = writeOp.getPermutationMap();
    if (!permutationMap.isPermutation())
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureNonPermutationMap);

    bool transposed = !permutationMap.isIdentity();
    if (transposed)
      return rewriter.notifyMatchFailure(writeOp,
                                         "TODO: transpose unsupported");

    auto vectorType = writeOp.getVectorType();
    if (!isMultipleOfSMETileVectorType(vectorType))
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureNotSMETileTypeMultiple);

    // Note: We also disallow masks where any dimension is > 16 because that
    // prevents the masking from being lowered to use arm_sve.psel.
    auto mask = writeOp.getMask();
    if (!isSupportedMaskOp(mask) || (mask && (vectorType.getDimSize(0) > 16 ||
                                              vectorType.getDimSize(1) > 16)))
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureUnsupportedMaskOp);

    auto loc = writeOp.getLoc();
    // Helper that emits `constant * vscale` index computations.
    auto createVscaleMultiple =
        vector::makeVscaleConstantBuilder(rewriter, loc);

    // Get SME tile and slice types.
    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
    // Number of slices per SME tile (scaled by vscale at runtime).
    auto minTileSlices = smeTileType.getDimSize(0);
    VectorType sliceMaskType =
        VectorType::get(minTileSlices, rewriter.getI1Type(), true);

    // Create loop over all tile slices.
    auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0);
    auto upperBound = createVscaleMultiple(minTileSlices);
    auto step = arith::ConstantIndexOp::create(rewriter, loc, 1);
    auto storeLoop =
        scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step);
    // All slice stores (for every sub-tile) are emitted inside the loop body.
    rewriter.setInsertionPointToStart(storeLoop.getBody());

    // For each sub-tile of the multi-tile `vectorType`.
    auto inputSMETiles = adaptor.getValueToStore();
    auto tileSliceIndex = storeLoop.getInductionVar();
    for (auto [index, smeTile] : llvm::enumerate(
             decomposeToSMETiles(rewriter, vectorType, smeTileType))) {
      // The coordinates of the tile within `vectorType`.
      auto tileRow = createVscaleMultiple(smeTile.row);
      auto tileCol = createVscaleMultiple(smeTile.col);

      // The current slice of `vectorType` we are processing.
      auto sliceIndex =
          arith::AddIOp::create(rewriter, loc, tileRow, tileSliceIndex);

      // Where in the destination memref the current slice will be stored.
      auto storeRow = arith::AddIOp::create(rewriter, loc, sliceIndex,
                                            writeOp.getIndices()[0]);
      auto storeCol = arith::AddIOp::create(rewriter, loc, tileCol,
                                            writeOp.getIndices()[1]);

      // Extract the mask for the current slice.
      Value sliceMask = nullptr;
      if (mask) {
        sliceMask = vector::ExtractOp::create(rewriter, loc, mask,
                                              OpFoldResult(sliceIndex));
        // If the mask is wider than one tile, take the sub-slice covering
        // this tile's columns.
        if (sliceMaskType != sliceMask.getType())
          sliceMask = vector::ScalableExtractOp::create(
              rewriter, loc, sliceMaskType, sliceMask, smeTile.col);
      }

      // Extract and store the current slice.
      Value tile = inputSMETiles[index];
      auto slice =
          vector::ExtractOp::create(rewriter, loc, tile, tileSliceIndex);
      // The slice is 1-D, so drop the leading result of the permutation map
      // and the leading in_bounds value from the original 2-D write.
      vector::TransferWriteOp::create(
          rewriter, loc, slice, writeOp.getBase(),
          ValueRange{storeRow, storeCol},
          AffineMapAttr::get(writeOp.getPermutationMap().dropResult(0)),
          sliceMask,
          rewriter.getBoolArrayAttr(
              ArrayRef<bool>(writeOp.getInBoundsValues()).drop_front()));
    }

    rewriter.eraseOp(writeOp);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ArmSME-specific fixup canonicalizations/folds
//===----------------------------------------------------------------------===//

/// Folds an extract from a 3D `vector.create_mask` (which is a vector of
/// SME-like masks), into a compare and a 2D `vector.create_mask`. This is
/// necessary for the mask to be lowered to ArmSME.
///
/// Example:
///
///  BEFORE:
///  ```mlir
///  %mask = vector.create_mask %nonConstantDim, %a, %b : vector<4x[4]x[4]xi1>
///  %subMask = vector.extract %mask[2]
///          : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
///  ```
///
///  AFTER:
///  ```mlir
///  %extractionInTrueRegion = arith.cmpi slt, %c2, %nonConstantDim : index
///  %newMaskFrontDim = arith.select %extractionInTrueRegion, %a, %c0 : index
///  %subMask = vector.create_mask %newMaskFrontDim, %b : vector<[4]x[4]xi1>
///  ```
struct FoldExtractFromVectorOfSMELikeCreateMasks
    : public OpRewritePattern<vector::ExtractOp> {
  using OpRewritePattern<vector::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    auto loc = extractOp.getLoc();
    auto createMaskOp =
        extractOp.getSource().getDefiningOp<vector::CreateMaskOp>();
    if (!createMaskOp)
      return rewriter.notifyMatchFailure(
          extractOp, "extract not from vector.create_mask op");

    VectorType extractedMaskType =
        llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
    if (!extractedMaskType)
      return rewriter.notifyMatchFailure(extractOp,
                                         "extracted type is not a vector type");

    // An SME-like mask is 2D-scalable (i.e. [n]x[m]xi1).
    auto numScalable = extractedMaskType.getNumScalableDims();
    if (numScalable != 2)
      return rewriter.notifyMatchFailure(
          extractOp, "expected extracted type to be an SME-like mask");

    // TODO: Support multiple extraction indices.
    if (extractOp.getStaticPosition().size() != 1)
      return rewriter.notifyMatchFailure(
          extractOp, "only a single extraction index is supported");

    // The front dim of the create_mask decides which sub-masks are all-true
    // vs all-false. If it is constant this fold is unnecessary.
    auto frontMaskDim = createMaskOp.getOperand(0);
    if (frontMaskDim.getDefiningOp<arith::ConstantOp>())
      return rewriter.notifyMatchFailure(
          extractOp,
          "constant vector.create_masks dims should be folded elsewhere");

    auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
    auto extractionIndex = getValueOrCreateConstantIndexOp(
        rewriter, loc, extractOp.getMixedPosition()[0]);
    // If the extraction index is within the mask's true region, the sub-mask
    // keeps dim `%a` (operand 1); otherwise it becomes all-false (dim 0).
    auto extractionInTrueRegion = arith::CmpIOp::create(
        rewriter, loc, rewriter.getI1Type(), arith::CmpIPredicate::slt,
        extractionIndex, frontMaskDim);
    auto newMaskFrontDim =
        arith::SelectOp::create(rewriter, loc, extractionInTrueRegion,
                                createMaskOp.getOperand(1), zero);

    rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
        extractOp, extractedMaskType,
        ValueRange{newMaskFrontDim, createMaskOp.getOperand(2)});
    return success();
  }
};

/// Returns true for a vector type where no fixed dimension comes after a
/// scalable dimension (any fixed dim trailing a scalable dim is illegal).
bool isLegalVectorType(VectorType vType) {
  bool seenScalableDim = false;
  for (bool scalableFlag : vType.getScalableDims()) {
    if (scalableFlag) {
      seenScalableDim = true;
      continue;
    }
    // A fixed dim after a scalable dim makes the type illegal.
    if (seenScalableDim)
      return false;
  }
  return true;
}

/// Lifts an illegal vector.transpose and vector.transfer_read to a
/// memref.subview + memref.transpose, followed by a legal read.
///
/// 'Illegal' here means a leading scalable dimension and a fixed trailing
/// dimension, which has no valid lowering.
///
/// The memref.transpose is metadata-only transpose that produces a strided
/// memref, which eventually becomes a loop reading individual elements.
///
/// Example:
///
///  BEFORE:
///  ```mlir
///  %illegalRead = vector.transfer_read %memref[%a, %b]
///                  : memref<?x?xf32>, vector<[8]x4xf32>
///  %legalType = vector.transpose %illegalRead, [1, 0]
///                  : vector<[8]x4xf32> to vector<4x[8]xf32>
///  ```
///
///  AFTER:
///  ```mlir
///  %readSubview = memref.subview %memref[%a, %b] [%c8_vscale, %c4] [%c1, %c1]
///                  : memref<?x?xf32> to memref<?x?xf32>
///  %transpose = memref.transpose %readSubview (d0, d1) -> (d1, d0)
///                  : memref<?x?xf32> to memref<?x?xf32>
///  %legalType = vector.transfer_read %transpose[%c0, %c0]
///                  : memref<?x?xf32>, vector<4x[8]xf32>
///  ```
struct LiftIllegalVectorTransposeToMemory
    : public OpRewritePattern<vector::TransposeOp> {
  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;

  /// Returns the input of an extension op (`arith.extsi`/`extui`/`extf`),
  /// or a null Value if `op` is not such an op (or is null).
  static Value getExtensionSource(Operation *op) {
    if (isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(op))
      return op->getOperand(0);
    return {};
  }

  LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    auto sourceType = transposeOp.getSourceVectorType();
    auto resultType = transposeOp.getResultVectorType();
    if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
      return rewriter.notifyMatchFailure(transposeOp,
                                         kMatchFailureNotIllegalToLegal);

    // Look through extend for transfer_read.
    Value maybeRead = transposeOp.getVector();
    // Note: may be null if the source is a block argument; handled by
    // `isa_and_present` inside `getExtensionSource`.
    auto *transposeSourceOp = maybeRead.getDefiningOp();
    Operation *extendOp = nullptr;
    if (Value extendSource = getExtensionSource(transposeSourceOp)) {
      maybeRead = extendSource;
      extendOp = transposeSourceOp;
    }

    auto illegalRead = maybeRead.getDefiningOp<vector::TransferReadOp>();
    if (!illegalRead)
      return rewriter.notifyMatchFailure(
          transposeOp,
          "expected source to be (possibly extended) transfer_read");

    if (!illegalRead.getPermutationMap().isIdentity())
      return rewriter.notifyMatchFailure(
          illegalRead, "expected read to have identity permutation map");

    auto loc = transposeOp.getLoc();
    auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
    auto one = arith::ConstantIndexOp::create(rewriter, loc, 1);

    // Create a subview that matches the size of the illegal read vector type.
    // Scalable dims are materialized as `size * vscale`.
    auto readType = illegalRead.getVectorType();
    auto readSizes = llvm::map_to_vector(
        llvm::zip_equal(readType.getShape(), readType.getScalableDims()),
        [&](auto dim) -> Value {
          auto [size, isScalable] = dim;
          auto dimSize = arith::ConstantIndexOp::create(rewriter, loc, size);
          if (!isScalable)
            return dimSize;
          auto vscale = vector::VectorScaleOp::create(rewriter, loc);
          return arith::MulIOp::create(rewriter, loc, vscale, dimSize);
        });
    // Unit strides for every dimension of the subview.
    SmallVector<Value> strides(readType.getRank(), Value(one));
    auto readSubview =
        memref::SubViewOp::create(rewriter, loc, illegalRead.getBase(),
                                  illegalRead.getIndices(), readSizes, strides);

    // Apply the transpose to all values/attributes of the transfer_read:
    // - The mask
    Value mask = illegalRead.getMask();
    if (mask) {
      // Note: The transpose for the mask should fold into the
      // vector.create_mask/constant_mask op, which will then become legal.
      mask = vector::TransposeOp::create(rewriter, loc, mask,
                                         transposeOp.getPermutation());
    }
    // - The source memref
    mlir::AffineMap transposeMap = AffineMap::getPermutationMap(
        transposeOp.getPermutation(), getContext());
    auto transposedSubview = memref::TransposeOp::create(
        rewriter, loc, readSubview, AffineMapAttr::get(transposeMap));
    ArrayAttr inBoundsAttr = illegalRead.getInBoundsAttr();
    // - The `in_bounds` attribute
    if (inBoundsAttr) {
      SmallVector<Attribute> inBoundsValues(inBoundsAttr.begin(),
                                            inBoundsAttr.end());
      applyPermutationToVector(inBoundsValues, transposeOp.getPermutation());
      inBoundsAttr = rewriter.getArrayAttr(inBoundsValues);
    }

    // The new read produces the transposed (legal) type directly; keep the
    // pre-extension element type if we looked through an extension op.
    VectorType legalReadType = resultType.clone(readType.getElementType());
    // Note: The indices are all zero as the subview is already offset.
    SmallVector<Value> readIndices(illegalRead.getIndices().size(), zero);
    auto legalRead = vector::TransferReadOp::create(
        rewriter, loc, legalReadType, transposedSubview, readIndices,
        illegalRead.getPermutationMapAttr(), illegalRead.getPadding(), mask,
        inBoundsAttr);

    // Replace the transpose with the new read, extending the result if
    // necessary.
    rewriter.replaceOp(transposeOp, [&]() -> Operation * {
      if (extendOp)
        return rewriter.create(loc, extendOp->getName().getIdentifier(),
                               Value(legalRead), resultType);
      return legalRead;
    }());

    return success();
  }
};

/// Rewrites an illegal/unsupported SVE transfer_write(transpose) to instead use
/// the ZA state. This workaround rewrite to support these transposes when ZA is
/// available.
///
/// Example:
///
///  BEFORE:
///  ```mlir
///  %transpose = vector.transpose %vec, [1, 0]
///     : vector<2x[4]xf32> to vector<[4]x2xf32>
///  vector.transfer_write %transpose, %dest[%y, %x]
///     : vector<[4]x2xf32>,  memref<?x?xf32>
///  ```
///
///  AFTER:
///  ```mlir
///   %0 = arm_sme.get_tile : vector<[4]x[4]xf32>
///   %1 = vector.extract %vec[0] : vector<[4]xf32> from vector<2x[4]xf32>
///   %2 = vector.insert %1, %0 [0] : vector<[4]xf32> into vector<[4]x[4]xf32>
///   %3 = vector.extract %vec[1] : vector<[4]xf32> from vector<2x[4]xf32>
///   %4 = vector.insert %3, %2 [1] : vector<[4]xf32> into vector<[4]x[4]xf32>
///   %c4_vscale = arith.muli %vscale, %c4 : index
///   %mask = vector.create_mask %c4_vscale, %c2 : vector<[4]x[4]xi1>
///   vector.transfer_write %4, %dest[%y, %x], %mask
///      {permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
///      : vector<[4]x[4]xf32>, memref<?x?xf32>
///  ```
///
/// Values larger than a single tile are supported via decomposition.
struct LowerIllegalTransposeStoreViaZA
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    // Bail on masks this pattern cannot reason about (see the unchecked
    // `create_mask` dereference in step 3 below).
    if (!isSupportedMaskOp(writeOp.getMask()))
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureUnsupportedMaskOp);

    // The original write must be a plain (identity-map) write; the transposing
    // permutation map is introduced on the per-tile writes emitted below.
    auto permutationMap = writeOp.getPermutationMap();
    if (!permutationMap.isIdentity())
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureNonPermutationMap);

    // Only fires on transfer_write(transpose).
    auto transposeOp = writeOp.getVector().getDefiningOp<vector::TransposeOp>();
    if (!transposeOp)
      return failure();

    auto sourceType = transposeOp.getSourceVectorType();
    auto resultType = transposeOp.getResultVectorType();

    if (resultType.getRank() != 2)
      return rewriter.notifyMatchFailure(transposeOp, "TransposeOp not rank 2");

    // Only transposes from a legal source type to an illegal/unsupported
    // result type need this ZA workaround.
    if (!isLegalVectorType(sourceType) || isLegalVectorType(resultType))
      return rewriter.notifyMatchFailure(
          transposeOp, "not illegal/unsupported SVE transpose");

    auto smeTileType = getSMETileTypeForElement(resultType.getElementType());
    VectorType smeSliceType = VectorType::Builder(smeTileType).dropDim(0);

    // Single-row sources don't need this rewrite, and the source's trailing
    // dim must decompose evenly into SME tile slices.
    if (sourceType.getDimSize(0) <= 1 ||
        sourceType.getDimSize(1) % smeSliceType.getDimSize(0) != 0)
      return rewriter.notifyMatchFailure(writeOp, "unsupported source shape");

    auto loc = writeOp.getLoc();
    auto createVscaleMultiple =
        vector::makeVscaleConstantBuilder(rewriter, loc);

    // Permutation map (d0, d1) -> (d1, d0): attached to each tile write so
    // the store itself performs the transpose.
    auto transposeMap = AffineMapAttr::get(
        AffineMap::getPermutationMap(ArrayRef<int64_t>{1, 0}, getContext()));

    // Note: We need to use `get_tile` as there's no vector-level `undef`.
    Value undefTile = arm_sme::GetTileOp::create(rewriter, loc, smeTileType);
    Value destTensorOrMemref = writeOp.getBase();
    // Number of source rows that land in any one SME tile (fixed count).
    auto numSlicesPerTile =
        std::min(sourceType.getDimSize(0), smeTileType.getDimSize(0));
    auto numSlices =
        arith::ConstantIndexOp::create(rewriter, loc, numSlicesPerTile);
    // Sources larger than a single tile are handled by decomposing into
    // SME-tile-sized pieces, emitting one masked transposed write per piece.
    for (auto [index, smeTile] : llvm::enumerate(
             decomposeToSMETiles(rewriter, sourceType, smeTileType))) {
      // 1. _Deliberately_ drop a scalable dimension and insert a fixed number
      // of slices from the source type into the SME tile. Without checking
      // vscale (and emitting multiple implementations) we can't make use of the
      // rows of the tile after 1*vscale rows.
      Value tile = undefTile;
      for (int d = 0; d < numSlicesPerTile; ++d) {
        Value vector =
            vector::ExtractOp::create(rewriter, loc, transposeOp.getVector(),
                                      rewriter.getIndexAttr(d + smeTile.row));
        if (vector.getType() != smeSliceType) {
          // Source rows are wider than one tile slice: take this tile's
          // scalable sub-slice starting at the tile's column offset.
          vector = vector::ScalableExtractOp::create(
              rewriter, loc, smeSliceType, vector, smeTile.col);
        }
        tile = vector::InsertOp::create(rewriter, loc, vector, tile, d);
      }

      // 2. Transpose the tile position (rows and columns swap roles, and the
      // row offset becomes scalable).
      auto transposedRow = createVscaleMultiple(smeTile.col);
      auto transposedCol =
          arith::ConstantIndexOp::create(rewriter, loc, smeTile.row);

      // 3. Compute mask for tile store.
      Value maskRows;
      Value maskCols;
      if (auto mask = writeOp.getMask()) {
        // NOTE(review): `createMask` is dereferenced unchecked below —
        // presumably isSupportedMaskOp() guarantees any mask is produced by
        // vector.create_mask; confirm against its definition.
        auto createMask = mask.getDefiningOp<vector::CreateMaskOp>();
        // Offset the original mask bounds by this tile's (transposed)
        // position; clamp the column count to the number of valid slices.
        maskRows = arith::SubIOp::create(
            rewriter, loc, createMask.getOperand(0), transposedRow);
        maskCols = arith::SubIOp::create(
            rewriter, loc, createMask.getOperand(1), transposedCol);
        maskCols = index::MinSOp::create(rewriter, loc, maskCols, numSlices);
      } else {
        // Unmasked write: all rows (scalable) and the valid slice count.
        maskRows = createVscaleMultiple(smeTileType.getDimSize(0));
        maskCols = numSlices;
      }
      auto subMask = vector::CreateMaskOp::create(
          rewriter, loc, smeTileType.clone(rewriter.getI1Type()),
          ValueRange{maskRows, maskCols});

      // 4. Emit a transposed tile write.
      auto writeIndices = writeOp.getIndices();
      Value destRow =
          arith::AddIOp::create(rewriter, loc, transposedRow, writeIndices[0]);
      Value destCol =
          arith::AddIOp::create(rewriter, loc, transposedCol, writeIndices[1]);
      auto smeWrite = vector::TransferWriteOp::create(
          rewriter, loc, tile, destTensorOrMemref, ValueRange{destRow, destCol},
          transposeMap, subMask, writeOp.getInBounds());

      // Tensor-semantics writes are chained: each tile write consumes the
      // result of the previous one.
      if (writeOp.hasPureTensorSemantics())
        destTensorOrMemref = smeWrite.getResult();
    }

    if (writeOp.hasPureTensorSemantics())
      rewriter.replaceOp(writeOp, destTensorOrMemref);
    else
      rewriter.eraseOp(writeOp);

    return success();
  }
};

/// Lower `vector.transfer_read` of a scalable column to `scf::for`
///
/// Lowers a "read" of a scalable column from a MemRef, for which there is no
/// hardware operation that we could use, to a loop over the rows that loads
/// one element at a time.
///
///  BEFORE:
///  ```
///  %res = vector.transfer_read %mem[%a, %b] (...)
///    : memref<?x?xf32>, vector<[4]x1xf32>
///  ```
///
///  AFTER:
///  ```
///    %cst = arith.constant (...) : vector<[4]xf32>
///    %vscale = vector.vscale
///    %c4_vscale = arith.muli %vscale, %c4 : index
///    %scf = scf.for %lb = %c0 to %c4_vscale step %c1 iter_args(%arg4 = %cst)
///      -> (vector<[4]xf32>) {
///
///        %load = memref.load %mem[%arg3 + %a, %b] : memref<?x?xf32>
///        %vec = vector.insert %load, %cst [%arg3] : f32 into vector<[4]xf32>
///        scf.yield %vec : vector<[4]xf32>
///    }
///    %res = vector.shape_cast %scf : vector<[4]xf32> to vector<[4]x1xf32>
///  ```
///
///  TODO: This transformation isn't specific to SME - move it to the SVE
///  dialect.
///  TODO: Check the in_bounds attribute and generate vector.maskedload if
///  required.
struct LowerColumnTransferReadToLoops
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    // NOTE: This is a fairly low-level transformation, so we shouldn't be
    // adding support for Tensors without good rationale.
    if (readOp.hasPureTensorSemantics())
      return rewriter.notifyMatchFailure(
          readOp, "Tensor semantics are unsupported (either bufferize or "
                  "extend this pattern)");

    auto resType = readOp.getVectorType();

    if (resType.getRank() != 2)
      return rewriter.notifyMatchFailure(readOp,
                                         "Only 2D vectors are supported!");

    if (resType.getShape()[1] != 1)
      return rewriter.notifyMatchFailure(
          readOp, "The trailing output dim is != 1 (not supported ATM)");

    // Only vector<[N]x1xTy> — scalable leading dim, fixed unit trailing dim.
    if (!resType.getScalableDims()[0] || resType.getScalableDims()[1])
      return rewriter.notifyMatchFailure(
          readOp, "Expected the leading dim to be scalable and the trailing "
                  "dim to be fixed.");

    // Create new result type - similar to the original vector with the
    // trailing unit dim collapsed.
    int64_t numRows = resType.getShape()[0];
    VectorType newResType = VectorType::get(numRows, resType.getElementType(),
                                            /*scalableDims=*/{true});

    // Create a loop over all rows and load one element at a time.
    auto loc = readOp.getLoc();
    auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0);
    auto createVscaleMultiple =
        vector::makeVscaleConstantBuilder(rewriter, loc);
    auto upperBound = createVscaleMultiple(numRows);
    auto step = arith::ConstantIndexOp::create(rewriter, loc, 1);
    // Loop-carried init value. Every lane is overwritten by the loop, so the
    // value is irrelevant — but it must be typed correctly. Use getZeroAttr
    // rather than a hard-coded f32 zero so this works for any element type
    // (f16/f64/integers/etc.), not just f32.
    Value init = arith::ConstantOp::create(rewriter, loc, newResType,
                                           rewriter.getZeroAttr(newResType));

    scf::ForOp loadLoop;
    {
      OpBuilder::InsertionGuard g(rewriter);
      loadLoop = scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step,
                                    ValueRange{init});
      rewriter.setInsertionPointToStart(loadLoop.getBody());

      auto tileSliceIndex = loadLoop.getInductionVar();

      // Row index = loop induction variable + the read's base row index.
      auto idx0 = arith::AddIOp::create(rewriter, loc, tileSliceIndex,
                                        readOp.getIndices()[0]);
      // Column index is fixed for the whole column read.
      auto idx1 = readOp.getIndices()[1];

      Value scalar = memref::LoadOp::create(rewriter, loc, readOp.getBase(),
                                            SmallVector<Value>({idx0, idx1}));

      // Insert the loaded scalar into the loop-carried vector and yield it.
      Operation *updateInit = vector::InsertOp::create(
          rewriter, loc, scalar, loadLoop.getRegionIterArg(0), tileSliceIndex);

      scf::YieldOp::create(rewriter, loc, updateInit->getResult(0));
    }

    // The read operation has been "legalized", but since the original result
    // type was a 2D vector, we need to cast before returning the result. This
    // ShapeCast should cancel-out with some other ShapeCast (i.e. it's a
    // no-op).
    auto sc = vector::ShapeCastOp::create(
        rewriter, loc, readOp.getResult().getType(), loadLoop.getResult(0));

    rewriter.replaceOp(readOp, sc);

    return success();
  }
};

struct VectorLegalizationPass
    : public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
  void runOnOperation() override {
    MLIRContext *ctx = &getContext();

    // Types are legal as-is, except vector types that span multiple SME
    // tiles, which convert 1->N into a list of SME-tile-sized vectors.
    TypeConverter typeConverter;
    typeConverter.addConversion([](Type type) { return type; });
    typeConverter.addConversion(
        [](VectorType vectorType,
           SmallVectorImpl<Type> &resultTypes) -> std::optional<LogicalResult> {
          if (!isMultipleOfSMETileVectorType(vectorType))
            return std::nullopt;
          auto tileCount = getNumberOfSMETilesForVectorType(vectorType);
          auto tileType = getSMETileTypeForElement(vectorType.getElementType());
          resultTypes = SmallVector<Type>(tileCount, tileType);
          return success();
        });

    // Stage 1: greedily apply preprocessing rewrites that make illegal ops
    // decomposable before the type-driven conversion runs.
    {
      RewritePatternSet preprocessingPatterns(ctx);
      preprocessingPatterns
          .add<FoldExtractFromVectorOfSMELikeCreateMasks,
               LowerColumnTransferReadToLoops,
               LiftIllegalVectorTransposeToMemory,
               LowerIllegalTransposeStoreViaZA>(ctx);
      if (failed(applyPatternsGreedily(getOperation(),
                                       std::move(preprocessingPatterns))))
        return signalPassFailure();
    }

    // Stage 2: partial dialect conversion that decomposes multi-tile vector
    // ops into SME-tile-sized pieces.
    RewritePatternSet conversionPatterns(ctx);
    // Note: These two patterns are added with a high benefit to ensure:
    //  - Masked outer products are handled before unmasked ones
    //  - Multi-tile writes are lowered as a store loop (if possible)
    conversionPatterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition,
                           LegalizeMultiTileTransferWriteAsStoreLoop>(
        typeConverter, ctx, /*benefit=*/1024);
    conversionPatterns.add<LegalizeArithConstantOpsByDecomposition,
                           LegalizeVectorOuterProductOpsByDecomposition,
                           LegalizeTransferReadOpsByDecomposition,
                           LegalizeTransferWriteOpsByDecomposition>(
        typeConverter, ctx);
    // Structural patterns so function signatures, calls, returns, and SCF
    // constructs are updated consistently with the 1->N type conversion.
    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
        conversionPatterns, typeConverter);
    populateCallOpTypeConversionPattern(conversionPatterns, typeConverter);
    populateReturnOpTypeConversionPattern(conversionPatterns, typeConverter);
    scf::populateSCFStructuralTypeConversions(typeConverter,
                                              conversionPatterns);

    // Any op whose types the converter accepts is legal; functions are legal
    // once their signatures no longer mention multi-tile vector types.
    ConversionTarget target(getContext());
    target.markUnknownOpDynamicallyLegal(
        [&](Operation *op) { return typeConverter.isLegal(op); });
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return typeConverter.isSignatureLegal(op.getFunctionType());
    });
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(conversionPatterns))))
      return signalPassFailure();
  }
};

} // namespace

/// Public factory for the ArmSME vector legalization pass.
std::unique_ptr<Pass> mlir::arm_sme::createVectorLegalizationPass() {
  return std::make_unique<VectorLegalizationPass>();
}
